package kvcache

import (
	"fmt"
	"math"
	"slices"
	"testing"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// testCase describes one forward pass through a cache and what the pass is
// expected to produce.
//
//   - name: subtest name passed to t.Run.
//   - in, inShape: flat data and shape of the tensor that testCache stores as
//     both key and value for layer 0.
//   - seqs, pos: sequence ID and position of each entry in the batch.
//   - expected, expectedShape: flat data and shape of the tensor returned by
//     Cache.Get after the Put.
//   - expectedMask: flattened mask returned by Cache.Get, one row per batch
//     entry; 0 means visible and -Inf means masked.
type testCase struct {
	name          string
	in            []float32
	inShape       []int
	seqs          []int
	pos           []int32
	expected      []float32
	expectedShape []int
	expectedMask  []float32
}

// runPermutedVariants runs fn twice as subtests, once with a backend whose
// CacheConfig reports PermutedV=false and once with PermutedV=true, so every
// cache test covers both value layouts.
func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
	t.Helper()

	variants := [...]bool{false, true}
	for _, permutedV := range variants {
		subtest := fmt.Sprintf("PermutedV=%t", permutedV)
		t.Run(subtest, func(t *testing.T) {
			backend := &testBackend{permutedV: permutedV}
			fn(t, backend)
		})
	}
}

// TestStore checks that a causal cache places each batch after the previous
// one and returns every stored entry together with a causal mask.
func TestStore(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
				inShape:       []int{2, 3, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
				expectedShape: []int{2, 3, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					0, 0, 0, x,
					0, 0, 0, 0,
				},
			},
			// The single new entry is returned after the four stored ones and
			// can see all of them.
			{
				name:          "SecondBatch",
				in:            []float32{115, 215, 125, 225, 135, 235},
				inShape:       []int{2, 3, 1},
				seqs:          []int{0},
				pos:           []int32{4},
				expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
				expectedShape: []int{2, 3, 5},
				expectedMask:  []float32{0, 0, 0, 0, 0},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

// TestSWA checks a sliding-window cache created with window size 1. In the
// expected masks each query sees its own position and the one immediately
// before it. Cells whose positions no query can see anymore are reused for
// new entries.
func TestSWA(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWACache(1, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			// Positions 0-3 fill the first four cells.
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, 0, 0, x,
					x, x, 0, 0,
				},
			},
			// Positions 4 and 5 (values 5 and 6) take over the cells that held
			// positions 0 and 1, so expected lists cells as 5, 6, 3, 4. The
			// query at position 4 sees itself and position 3 (value 4).
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{5, 6, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, 0,
					0, 0, x, x,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

// TestSWASeparateBatches interleaves batches from two sequences through a
// sliding-window cache with window size 1. It checks that cells are shared
// and reused across sequences, while each sequence's mask hides the other
// sequence's entries. Init is called with 2 as its third and fifth arguments
// (presumably max sequences and max batch size — TODO confirm).
func TestSWASeparateBatches(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWACache(1, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 2, 16, 2)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "First seq 0",
				in:            []float32{1, 2},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{0, 1},
				expected:      []float32{1, 2},
				expectedShape: []int{1, 1, 2},
				expectedMask: []float32{
					0, x,
					0, 0,
				},
			},
			// Position 0 (value 1) is outside both new queries' windows, so
			// the returned range starts at value 2.
			{
				name:          "Second seq 0",
				in:            []float32{3, 4},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{2, 3},
				expected:      []float32{2, 3, 4},
				expectedShape: []int{1, 1, 3},
				expectedMask: []float32{
					0, 0, x,
					x, 0, 0,
				},
			},
			{
				name:          "First seq 1",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{0, 1},
				expected:      []float32{5, 6},
				expectedShape: []int{1, 1, 2},
				expectedMask: []float32{
					0, x,
					0, 0,
				},
			},
			// Value 6 sits directly before sequence 0's values 3 and 4, and the
			// new entries 7 and 8 follow them. The returned range therefore
			// mixes both sequences. The mask hides sequence 0's entries (3, 4)
			// from sequence 1's queries.
			{
				name:          "Second seq 1",
				in:            []float32{7, 8},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{2, 3},
				expected:      []float32{6, 3, 4, 7, 8},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, x, x, 0, x,
					x, x, x, 0, 0,
				},
			},
			// Sequence 0's new entries (9, 10) take over the first two cells,
			// which previously held sequence 1's values 5 and 6.
			{
				name:          "Third seq 0",
				in:            []float32{9, 10},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{9, 10, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, 0,
					0, 0, x, x,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

// TestSWAMem checks a sliding-window cache created by NewSWAMemCache with
// window 1 and a second argument of 3 (presumably the number of positions to
// keep in memory beyond the attention window — TODO confirm). The masks
// match TestSWA, but more old entries stay resident, so new entries are
// placed differently.
func TestSWAMem(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWAMemCache(1, 3, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, 0, 0, x,
					x, x, 0, 0,
				},
			},
			// Value 5 takes over position 0's cell. Unlike TestSWA, value 2
			// (position 1) is still kept, so value 6 goes to a fresh fifth
			// cell. The mask still limits each query to the window.
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{5, 2, 3, 4, 6},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, x, x, 0, x,
					0, x, x, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

// TestChunkedAttention checks a chunked-attention cache with chunk size 2.
// In the expected masks a query sees only earlier positions inside its own
// chunk ({0,1}, {2,3}, {4,5}, {6,7}, {8}). All stored entries are still
// returned.
func TestChunkedAttention(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewChunkedAttentionCache(2, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, x, 0, x,
					x, x, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6, 7},
				inShape:       []int{1, 1, 3},
				seqs:          []int{0, 0, 0},
				pos:           []int32{4, 5, 6},
				expected:      []float32{1, 2, 3, 4, 5, 6, 7},
				expectedShape: []int{1, 1, 7},
				expectedMask: []float32{
					x, x, x, x, 0, x, x,
					x, x, x, x, 0, 0, x,
					x, x, x, x, x, x, 0,
				},
			},
			{
				name:          "ThirdBatch",
				in:            []float32{8, 9},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{7, 8},
				expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
				expectedShape: []int{1, 1, 9},
				expectedMask: []float32{
					x, x, x, x, x, x, 0, 0, x,
					x, x, x, x, x, x, x, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

// TestSequences checks that entries from different sequences share one
// causal cache but are masked from each other.
func TestSequences(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 1, 1},
				pos:           []int32{0, 1, 0, 1},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, x, 0, x,
					x, x, 0, 0,
				},
			},
			// One new entry per sequence; each row sees only its own
			// sequence's earlier entries plus itself.
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 1},
				pos:           []int32{2, 2},
				expected:      []float32{1, 2, 3, 4, 5, 6},
				expectedShape: []int{1, 1, 6},
				expectedMask: []float32{
					0, 0, x, x, 0, x,
					x, x, 0, 0, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

// TestRemove checks that Cache.Remove frees cells for reuse and that
// surviving later entries of the sequence are passed through the shift
// function. The shift function here adds the shift tensor to the key, so a
// shifted entry's stored value changes by the shift amount.
func TestRemove(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
			return key.Add(ctx, shift), nil
		})
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 1, 1},
				pos:           []int32{0, 1, 0, 1},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, x, 0, x,
					x, x, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		// Drop sequence 0 from position 1 onward. Its freed cell (value 2) is
		// reused by the next sequence-0 entry (value 5).
		if err := cache.Remove(0, 1, math.MaxInt32); err != nil {
			t.Fatalf("Remove(0, 1, MaxInt32): %v", err)
		}

		tests = []testCase{
			{
				name:          "RemoveEnd",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 1},
				pos:           []int32{1, 2},
				expected:      []float32{1, 5, 3, 4, 6},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, 0, x, x, x,
					x, x, 0, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		// Drop sequence 0's position 0. The remaining sequence-0 entry
		// (value 5) is shifted down, and the Add-based shift function turns it
		// into value 4 in the next output.
		if err := cache.Remove(0, 0, 1); err != nil {
			t.Fatalf("Remove(0, 0, 1): %v", err)
		}

		tests = []testCase{
			{
				name:          "RemoveMiddle",
				in:            []float32{7, 8},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{1, 2},
				expected:      []float32{7, 4, 3, 4, 6, 8},
				expectedShape: []int{1, 1, 6},
				expectedMask: []float32{
					0, 0, x, x, x, x,
					0, 0, x, x, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

// TestCopy checks that CopyPrefix(0, 1, 2) makes the first two entries of
// sequence 0 visible to later queries of sequence 1, while sequence 0's
// other entries stay masked.
func TestCopy(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					0, 0, 0, x,
					0, 0, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		cache.CopyPrefix(0, 1, 2)

		tests = []testCase{
			{
				name:          "Copy",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{3, 4},
				expected:      []float32{1, 2, 3, 4, 5, 6},
				expectedShape: []int{1, 1, 6},
				expectedMask: []float32{
					0, 0, x, x, 0, x,
					0, 0, x, x, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

// testCache runs each case in tests, in order, as a subtest against cache.
// For each case it starts a forward pass with the case's positions and
// sequences, stores test.in as both key and value for layer 0, and compares
// the tensor and mask returned by cache.Get with the expectations. All cases
// share the same cache, so each one observes the state left by the previous
// ones.
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ctx := backend.NewContext()
			defer ctx.Close()

			err := cache.StartForward(ctx, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
			if err != nil {
				// Fail only this subtest instead of aborting the whole test
				// binary with a panic.
				t.Fatalf("StartForward: %v", err)
			}

			cache.SetLayer(0)
			tensor := ctx.FromFloats(test.in, test.inShape...)
			cache.Put(ctx, tensor, tensor)

			out, _, mask := cache.Get(ctx)

			ctx.Forward(out, mask).Compute(out, mask)

			if got := out.Floats(); !slices.Equal(got, test.expected) {
				t.Errorf("TestCache: have %v; want %v", got, test.expected)
			}

			if got := out.Shape(); !slices.Equal(got, test.expectedShape) {
				t.Errorf("TestCache: has shape %v; want %v", got, test.expectedShape)
			}

			if got := mask.Floats(); !slices.Equal(got, test.expectedMask) {
				t.Errorf("TestCache: have mask %v; want %v", got, test.expectedMask)
			}
		})
	}
}

// TestCanResume checks CanResume on a sliding-window cache with window size
// 4, first while every cached position is still resumable and again after a
// new position has been added.
func TestCanResume(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		windowSize := int32(4)
		cache := NewSWACache(windowSize, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		ctx := backend.NewContext()
		defer ctx.Close()

		err := cache.StartForward(ctx, input.Batch{
			Positions: []int32{0, 1, 2, 3, 4},
			Sequences: []int{0, 0, 0, 0, 0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor := ctx.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
		cache.Put(ctx, tensor, tensor)

		// With window size 4, nothing has slid out of the window yet, so every
		// cached position 0-4 can be resumed from. Position 4 is the latest.
		for pos := int32(0); pos <= 4; pos++ {
			reason := "within window"
			if pos == 4 {
				reason = "latest position"
			}
			if !cache.CanResume(0, pos) {
				t.Errorf("CanResume(0, %d) = false, want true (%s)", pos, reason)
			}
		}

		// Shift the window by adding position 5.
		err = cache.StartForward(ctx, input.Batch{
			Positions: []int32{5},
			Sequences: []int{0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor = ctx.FromFloats([]float32{6}, 1, 1, 1)
		cache.Put(ctx, tensor, tensor)

		// Now only the latest position can be resumed from.
		for pos := int32(0); pos <= 4; pos++ {
			if cache.CanResume(0, pos) {
				t.Errorf("after shift: CanResume(0, %d) = true, want false (outside window)", pos)
			}
		}
		if !cache.CanResume(0, 5) {
			t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
		}
	})
}

// TestCanResumeSWAMem checks CanResume on a sliding-window cache created by
// NewSWAMemCache with window size 4 and memory size 5, after positions 0-7
// have been stored across two forward passes.
func TestCanResumeSWAMem(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		windowSize := int32(4)
		memSize := int32(5)
		cache := NewSWAMemCache(windowSize, memSize, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		ctx := backend.NewContext()
		defer ctx.Close()

		err := cache.StartForward(ctx, input.Batch{
			Positions: []int32{0, 1, 2, 3, 4, 5, 6},
			Sequences: []int{0, 0, 0, 0, 0, 0, 0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor := ctx.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
		cache.Put(ctx, tensor, tensor)

		// Shift the window by adding position 7.
		err = cache.StartForward(ctx, input.Batch{
			Positions: []int32{7},
			Sequences: []int{0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor = ctx.FromFloats([]float32{8}, 1, 1, 1)
		cache.Put(ctx, tensor, tensor)

		// Only positions 6 and 7 (the latest) can still be resumed from;
		// positions 0-5 cannot.
		for pos := int32(0); pos <= 5; pos++ {
			if cache.CanResume(0, pos) {
				t.Errorf("after shift: CanResume(0, %d) = true, want false (outside window)", pos)
			}
		}
		if !cache.CanResume(0, 6) {
			t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
		}
		if !cache.CanResume(0, 7) {
			t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
		}
	})
}

// testBackend is a minimal ml.Backend stand-in for cache tests. It overrides
// only NewContext, NewContextSize and CacheConfig. The embedded ml.Backend is
// left nil, so calling any other Backend method panics. permutedV is the
// value CacheConfig reports as PermutedV.
type testBackend struct {
	ml.Backend
	permutedV bool
}

// NewContext returns a fresh mock context.
func (b *testBackend) NewContext() ml.Context {
	return new(testContext)
}

// NewContextSize ignores the requested size and returns a fresh mock context.
func (b *testBackend) NewContextSize(int) ml.Context {
	return new(testContext)
}

// CacheConfig exposes the backend's permutedV flag as
// ml.CacheConfig.PermutedV; every other field keeps its zero value.
func (b *testBackend) CacheConfig() ml.CacheConfig {
	var cfg ml.CacheConfig
	cfg.PermutedV = b.permutedV
	return cfg
}

// testContext is a minimal ml.Context stand-in whose tensor constructors
// return *testTensor values. Operations run eagerly, so Forward/Compute are
// no-ops. The embedded ml.Context is left nil, so calling any method not
// defined on testContext panics.
type testContext struct {
	ml.Context
}

// Empty returns a zero-filled *testTensor with the given dtype tag and shape.
// Data is always stored as float32 with a 4-byte element size. A call with
// no dimensions yields a tensor with no elements. The shape slice is
// retained, not copied.
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	var count int
	for i, dim := range shape {
		if i == 0 {
			count = dim
		} else {
			count *= dim
		}
	}

	return &testTensor{
		dtype:       dtype,
		elementSize: 4,
		data:        make([]float32, count),
		shape:       shape,
	}
}

// Zeros is the same as Empty, because Empty already allocates zero-filled
// data.
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	zeroed := c.Empty(dtype, shape...)
	return zeroed
}

// FromFloats returns an F32 tensor of the given shape initialised from s. If
// s is shorter than the shape's element count, the remainder stays zero; if
// it is longer, the extra values are dropped.
func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
	tensor := c.Empty(ml.DTypeF32, shape...).(*testTensor)
	copy(tensor.data, s)
	return tensor
}

// FromInts returns a tensor tagged DTypeI32 with the given shape, holding s.
// The mock stores every element as float32, so the values are converted on
// the way in.
func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
	converted := make([]float32, 0, len(s))
	for _, v := range s {
		converted = append(converted, float32(v))
	}

	tensor := c.FromFloats(converted, shape...).(*testTensor)
	tensor.dtype = ml.DTypeI32
	return tensor
}

// Arange returns a 1-D tensor tagged with dtype holding start, start+step,
// start+2*step, ... for every value strictly below stop, i.e.
// ceil((stop-start)/step) elements.
//
// Each value is computed as start+i*step rather than by repeated addition,
// so fractional steps do not accumulate rounding error. If stop <= start the
// result is empty (previously this could panic with a negative slice
// capacity). If start < stop and step is not positive, it panics, because
// the range would never terminate.
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	var s []float32
	if start < stop {
		if step <= 0 {
			panic("arange: step must be positive when start < stop")
		}

		n := int(math.Ceil(float64((stop - start) / step)))
		s = make([]float32, n)
		for i := range s {
			s[i] = start + float32(i)*step
		}
	}

	out := c.FromFloats(s, len(s))
	out.(*testTensor).dtype = dtype
	return out
}

// Input returns the receiver; the mock has no separate input context.
func (c *testContext) Input() ml.Context {
	return c
}

// Layer returns the receiver for every layer index.
func (c *testContext) Layer(int) ml.Context {
	return c
}

// Forward ignores its tensors and returns the receiver so that Compute can be
// chained. Mock operations execute eagerly, so there is no graph to record.
func (c *testContext) Forward(...ml.Tensor) ml.Context {
	return c
}

// Compute is a no-op: results were already computed when each op ran.
func (c *testContext) Compute(...ml.Tensor) {}

// Reserve is a no-op.
func (c *testContext) Reserve() {}

// MaxGraphNodes reports a fixed budget of 10 graph nodes.
func (c *testContext) MaxGraphNodes() int { return 10 }

// Close is a no-op; the mock holds no resources.
func (c *testContext) Close() {}

// testTensor is an eager, in-memory ml.Tensor stand-in.
//
//   - dtype is only a tag: every element is stored as float32 in data.
//   - elementSize is the per-element byte size used by Stride and to convert
//     View's byte offset into an element index (4 for every tensor
//     constructed in this file).
//   - data is laid out contiguously with dimension 0 varying fastest.
//   - shape lists dimension sizes.
//
// The embedded ml.Tensor is left nil, so calling any method not defined on
// testTensor panics.
type testTensor struct {
	ml.Tensor

	dtype       ml.DType
	elementSize int
	data        []float32
	shape       []int
}

// Dim returns the size of dimension n; it panics if n is out of range.
func (t *testTensor) Dim(n int) int {
	dims := t.shape
	return dims[n]
}

// Stride returns the byte stride of dimension n in a contiguous layout: the
// element size multiplied by the sizes of all lower dimensions.
func (t *testTensor) Stride(n int) int {
	stride := t.elementSize
	for i := 0; i < n; i++ {
		stride *= t.shape[i]
	}
	return stride
}

// Shape returns the tensor's shape slice. It is shared, not copied.
func (t *testTensor) Shape() []int { return t.shape }

// DType returns the tensor's dtype tag.
func (t *testTensor) DType() ml.DType { return t.dtype }

// Floats returns a copy of the tensor's data, so callers may modify it freely.
func (t *testTensor) Floats() []float32 {
	return append(make([]float32, 0, len(t.data)), t.data...)
}

// Neg returns a new tensor with t's dtype and shape, holding the negation of
// each element of t.
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
	negated := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
	for i := range negated.data {
		negated.data[i] = -t.data[i]
	}
	return negated
}

// Add returns a new tensor with t's dtype and shape holding the elementwise
// sum of t and t2. t2 must be a *testTensor with at least as many elements as
// t; broadcasting is not supported, and any extra elements in t2 are ignored.
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	// Assert the operand type once instead of on every element.
	other := t2.(*testTensor)
	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)

	// Fail with a descriptive message rather than an opaque index panic.
	if len(other.data) < len(out.data) {
		panic(fmt.Sprintf("Add: operand has %d elements, want at least %d", len(other.data), len(out.data)))
	}

	for i := range out.data {
		out.data[i] = t.data[i] + other.data[i]
	}

	return out
}

// Reshape returns a tensor with the new shape that shares t's data slice. No
// elements are moved, and the element count is not checked against the new
// shape.
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	reshaped := &testTensor{shape: shape, data: t.data}
	reshaped.dtype = t.dtype
	reshaped.elementSize = t.elementSize
	return reshaped
}

// View returns a tensor that aliases a contiguous run of t's data starting at
// byte offset; writes through the view are visible in t.
//
// shape must have 1, 3 or 5 entries. Only the even-indexed entries (0, 2, 4)
// are used as dimension sizes. The odd-indexed entries (presumably byte
// strides — TODO confirm) are ignored, so the view always covers
// product(sizes) consecutive elements. The view's element size is fixed at 4.
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	start := offset / t.elementSize

	var dims []int
	switch len(shape) {
	case 1, 3, 5:
		for i := 0; i < len(shape); i += 2 {
			dims = append(dims, shape[i])
		}
	default:
		panic("unsupported number of dimensions")
	}

	count := 1
	for _, d := range dims {
		count *= d
	}

	return &testTensor{
		dtype:       t.dtype,
		elementSize: 4,
		data:        t.data[start : start+count],
		shape:       dims,
	}
}

// Permute returns a new tensor with t's axes rearranged: source axis i moves
// to destination axis order[i] (newShape4[order[i]] = shape4[i]).
//
// At most 4 dimensions are supported. order must have either len(t.shape)
// entries or exactly 4; missing trailing entries get the identity mapping.
// Unlike View and Reshape, the data is copied into a new contiguous buffer
// laid out for the new shape, with dimension 0 varying fastest. Trailing
// size-1 dimensions are trimmed from the returned shape, keeping at least one.
func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
	if len(t.shape) > 4 || len(order) > 4 {
		panic("permute only supports up to 4 dimensions")
	}

	if len(order) != len(t.shape) && len(order) != 4 {
		panic("invalid number of dimensions for permute")
	}

	// ggml_permute expects 4 axes, so fill in any missing dimensions.
	orderFull := append(make([]int, 0, 4), order...)
	for len(orderFull) < 4 {
		orderFull = append(orderFull, len(orderFull))
	}

	// seen rejects orders that send two source axes to the same destination.
	seen := [4]bool{}

	// Pad the source shape to 4 dimensions with 1s.
	shape4 := [4]int{1, 1, 1, 1}
	for i := 0; i < len(t.shape) && i < 4; i++ {
		shape4[i] = t.shape[i]
	}

	// Destination shape: source axis `axis` lands at position orderFull[axis].
	newShape4 := [4]int{1, 1, 1, 1}
	for axis := range 4 {
		dst := orderFull[axis]
		if dst < 0 || dst >= 4 {
			panic("invalid axis for permute")
		}
		if seen[dst] {
			panic("duplicate axis for permute")
		}
		seen[dst] = true
		newShape4[dst] = shape4[axis]
	}

	total := len(t.data)
	newData := make([]float32, total)

	if total > 0 {
		oldDims := shape4
		newDims := newShape4

		// Contiguous element strides, dimension 0 varying fastest.
		// NOTE(review): oldStride is computed but never read below; source
		// coordinates are recovered with div/mod instead.
		oldStride := [4]int{1, 1, 1, 1}
		newStride := [4]int{1, 1, 1, 1}
		for i := 1; i < 4; i++ {
			oldStride[i] = oldStride[i-1] * oldDims[i-1]
			newStride[i] = newStride[i-1] * newDims[i-1]
		}

		var coords [4]int
		var newCoords [4]int

		for idx := range total {
			// Decompose the flat source index into 4-D coordinates.
			// NOTE(review): the dim == 0 guard looks unreachable when
			// len(t.data) equals the shape's element count, since any zero
			// dimension would make total 0.
			remainder := idx
			for axis := range 4 {
				dim := oldDims[axis]
				if dim == 0 {
					coords[axis] = 0
					continue
				}
				coords[axis] = remainder % dim
				remainder /= dim
			}

			// Scatter each source coordinate to its destination axis.
			for axis := range 4 {
				newCoords[orderFull[axis]] = coords[axis]
			}

			// Re-linearize with the destination strides.
			newIndex := 0
			for axis := range 4 {
				if newDims[axis] == 0 {
					continue
				}
				newIndex += newCoords[axis] * newStride[axis]
			}

			newData[newIndex] = t.data[idx]
		}
	}

	// Drop trailing size-1 dimensions, keeping at least one.
	numDims := 4
	for numDims > 1 && newShape4[numDims-1] <= 1 {
		numDims--
	}

	newShape := make([]int, numDims)
	copy(newShape, newShape4[:numDims])

	return &testTensor{
		dtype:       t.dtype,
		elementSize: t.elementSize,
		data:        newData,
		shape:       newShape,
	}
}

// SetRows writes rows of src into t in place and returns t.
//
// A "row" is a run of shape[0] elements along dimension 1. Within each
// (dim 2, dim 3) slice, row r of src is copied to row idxs[r] of t. idxs is
// broadcast over dimensions 2 and 3 by taking the slice index modulo idxs'
// dimensions 1 and 2, so src's dimensions 2 and 3 must be multiples of those.
// Index values are read from the float32 data and truncated to int. An index
// outside t's dimension 1 panics, as does any shape mismatch checked below.
func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
	dst := t
	srcTensor := src.(*testTensor)
	idxTensor := idxs.(*testTensor)

	// shapeTo4D pads a shape to 4 dimensions with trailing 1s.
	shapeTo4D := func(shape []int) [4]int {
		out := [4]int{1, 1, 1, 1}
		for i := 0; i < len(shape) && i < 4; i++ {
			out[i] = shape[i]
		}
		return out
	}

	// computeStrides returns contiguous element strides, dimension 0 fastest.
	computeStrides := func(shape [4]int) [4]int {
		out := [4]int{1, 1, 1, 1}
		for i := 1; i < 4; i++ {
			out[i] = out[i-1] * shape[i-1]
		}
		return out
	}

	dstShape4D := shapeTo4D(dst.shape)
	srcShape4D := shapeTo4D(srcTensor.shape)
	idxShape4D := shapeTo4D(idxTensor.shape)

	// dst and src must agree on row width and the outer dimensions; only the
	// row count (dimension 1) may differ.
	if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
		panic("SetRows requires matching tensor shapes")
	}

	// Exactly one index per source row.
	if srcShape4D[1] != idxShape4D[0] {
		panic("SetRows rows/index mismatch")
	}

	// NOTE(review): a zero-sized idxs dimension causes an integer
	// divide-by-zero panic here, so the `> 0` guards in the loop below never
	// actually see 0.
	if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
		panic("SetRows cannot broadcast indices")
	}

	if idxShape4D[3] != 1 {
		panic("SetRows expects 1D or 2D index tensors")
	}

	dstStride := computeStrides(dstShape4D)
	srcStride := computeStrides(srcShape4D)
	idxStride := computeStrides(idxShape4D)

	numColumns := srcShape4D[0]
	numRows := srcShape4D[1]

	for dim3Index := range dstShape4D[3] {
		for dim2Index := range dstShape4D[2] {
			// Broadcast the index tensor over dims 2 and 3.
			idxDim2 := 0
			idxDim3 := 0
			if idxShape4D[1] > 0 {
				idxDim2 = dim2Index % idxShape4D[1]
			}
			if idxShape4D[2] > 0 {
				idxDim3 = dim3Index % idxShape4D[2]
			}

			idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
			srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
			dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]

			for row := range numRows {
				// idxStride[0] is always 1; indices are stored as float32.
				idx := int(idxTensor.data[idxBase+row*idxStride[0]])
				if idx < 0 || idx >= dstShape4D[1] {
					panic("SetRows index out of range")
				}

				srcOffset := srcBase + row*srcStride[1]
				dstOffset := dstBase + idx*dstStride[1]

				copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
			}
		}
	}

	return dst
}

// Copy copies t's elements into t2, which must be a *testTensor, and returns
// t2. Like SetRows, it returns its destination rather than nil, so callers
// that chain on or forward the result receive a usable tensor. At most
// min(len(t.data), len(t2.data)) elements are copied.
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	dst := t2.(*testTensor)
	copy(dst.data, t.data)
	return dst
}
